[NOIP2016]天天爱跑步

2019-11-14
NOIP

题意

在树上有一些简单链,问每个点是多少条链的第\(w_j\)个点

题解

把链分为上行的和下行的两条

在上行链上,造成贡献的必要条件是\(w_j+dep_j=dep_u\),u为j子树中的点

在下行链上,造成贡献的必要条件是\(w_j-dep_j=l-dep_v\),v为j子树中的点

dfs,顺便关于\(w_j+dep_j\)和\(w_j-dep[j]\)开桶统计即可

注意点1,当一条链两个端点的lca恰好在这条链有贡献,会被算两遍,要先减掉

注意点2,如果链的起点和终点都在子树中,lca以上就不能算贡献

因此,链的起点和终点的贡献要在lca处减掉

调试记录

注意点1,2

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
#include <cstdio>
#include <algorithm>
#include <vector>
const int maxn = 3e5 + 5;
using namespace std;
struct E{
int to, nxt;
}e[maxn << 1];
int head[maxn], tot = 0;
void addedge(int u, int v){
e[++tot].to = v, e[tot].nxt = head[u];
head[u] = tot;
}
int f[maxn][25], dep[maxn];
void dfs1(int cur, int fa){
f[cur][0] = fa; dep[cur] = dep[fa] + 1;
for (int i = 1; (1 << i) <= dep[cur]; i++)
f[cur][i] = f[f[cur][i - 1]][i - 1];
for (int i = head[cur]; i; i = e[i].nxt)
if (e[i].to != fa) dfs1(e[i].to, cur);
}
int LCA(int u, int v){
if (dep[u] > dep[v]) swap(u, v);
for (int i = 20; i >= 0; i--)
if (dep[v] - (1 << i) >= dep[u]) v = f[v][i];
if (u == v) return u;
for (int i = 20; i >= 0; i--)
if (f[u][i] != f[v][i]) u = f[u][i], v = f[v][i];
return f[u][0];
}
int s[maxn], ds[maxn << 1], dt[maxn << 2], res[maxn];
vector <int> t[maxn], tt[maxn], ss[maxn]; int w[maxn];
void dfs2(int cur, int fa){
int t1 = ds[w[cur] + dep[cur]], t2 = dt[w[cur] - dep[cur] + maxn];
ds[dep[cur]] += s[cur];
for (int j = 0; j < t[cur].size(); j++)
dt[t[cur][j] + maxn]++;
for (int i = head[cur]; i; i = e[i].nxt){
int v = e[i].to;
if (v == fa) continue;
dfs2(v, cur);
}
res[cur] += ds[w[cur] + dep[cur]] - t1 + dt[w[cur] - dep[cur] + maxn] - t2;
for (int j = 0; j < ss[cur].size(); j++)
ds[dep[ss[cur][j]]]--;
for (int j = 0; j < tt[cur].size(); j++)
dt[tt[cur][j] + maxn]--;
}
int n, m;
int main(){
scanf("%d%d", &n, &m);
for (int u, v, i = 1; i < n; i++){
scanf("%d%d", &u, &v);
addedge(u, v); addedge(v, u);
} dfs1(1, 0);
// printf("%d\n", LCA(2, 5));
for (int i = 1; i <= n; i++) scanf("%d", w + i);
for (int u, v, i = 1; i <= m; i++){
scanf("%d%d", &u, &v); int lca = LCA(u, v);
s[u]++; if (dep[lca] + w[lca] == dep[u]) res[lca]--;
ss[lca].push_back(u);
t[v].push_back(dep[u] - 2 * dep[lca]);
tt[lca].push_back(dep[u] - 2 * dep[lca]);
} dfs2(1, 0);
for (int i = 1; i <= n; i++) printf("%d ", res[i]);
return 0;
}